'''
Count the number of reads supporting each CLCL junction, using alignments exported from IGV.
Usage:
python junction_counting.py --bam <alignments.bam>
'''

import re
import sys
import pysam
import gzip
import argparse

class sub_read_info:
    '''Two-segment window over the alignment segments of one read.

    Holds the chromosome, strand, reference coordinates and read
    coordinates of the segment currently being examined, together with a
    snapshot of the segment examined before it (the ``prev_*`` attributes).
    '''

    def __init__(self):
        # Segment currently being examined.
        self.chrom, self.strand = '', ''
        self.start, self.end = 0, 0
        self.read_start, self.read_end = 0, 0
        # Snapshot of the segment examined before it.
        self.prev_chrom, self.prev_strand = '', ''
        self.prev_start, self.prev_end = 0, 0
        self.prev_read_end = 0

    def update_prev_subread(self):
        '''Copy the current segment into the ``prev_*`` attributes.'''
        self.prev_start, self.prev_end = self.start, self.end
        self.prev_read_end = self.read_end
        self.prev_chrom, self.prev_strand = self.chrom, self.strand

    def update_current_subread(self, chrom, strand, start, end, read_start, read_end):
        '''Load a new current segment.

        ``start`` and ``end`` are stored as given when ``strand`` is ``'+'``
        and swapped for any other strand value; the remaining arguments are
        stored unchanged.
        '''
        forward = strand == '+'
        self.start = start if forward else end
        self.end = end if forward else start
        self.read_start, self.read_end = read_start, read_end
        self.chrom, self.strand = chrom, strand

class sv:
    '''Junction record: a structural-variant type label plus a support counter.'''

    def __init__(self, sv_type):
        # The label is stored as given; the counter starts at 1.
        self.sv_type = sv_type
        self.support = 1

    def increment(self):
        '''Add one to the support counter.'''
        self.support = self.support + 1

def create_key(s, e):
    '''Return the string ``s`` and ``str(e)`` joined by a colon, e.g. ``'chr1:100'``.'''
    return ':'.join((s, str(e)))

def filter_reads(bamFile):
    '''Collect the names of reads that have a secondary alignment record.

    Args:
        bamFile: path of the BAM file to scan. ``fetch()`` is called without
            a region; NOTE(review): pysam presumably needs the BAM to be
            indexed for this -- confirm.

    Returns:
        A list of query names, one entry per alignment record whose FLAG has
        the secondary-alignment bit (0x100) set, in the order the records
        are fetched. A name appears more than once when the read has
        several secondary records.
    '''
    # The context manager closes the BAM handle even if iteration raises.
    with pysam.AlignmentFile(bamFile, "rb") as bamfile:
        return [read.query_name for read in bamfile.fetch() if read.is_secondary]

def call_deletion_cigar(read, junctions, chromosomes, min_length=2000):
    '''Record long reference gaps inside one alignment as ``DEL`` junctions.

    Walks the (query position, reference position) pairs of ``read`` and
    tracks runs of reference positions that have no query base. Pairs with
    no reference position are ignored and do not interrupt a run. When a
    run is closed by the next aligned base and is longer than
    ``min_length``, the junction keyed by its first and last gapped
    positions (1-based) is created, or its support is incremented if it
    already exists.

    Args:
        read: alignment exposing ``reference_id`` and ``get_aligned_pairs()``
            (presumably a ``pysam.AlignedSegment``).
        junctions: dict mapping ``(key, key)`` tuples to ``sv`` records;
            updated in place.
        chromosomes: sequence of reference names indexed by ``reference_id``.
        min_length: a gap must span more than this many reference positions
            to be recorded.

    Returns:
        None.
    '''
    chrom = chromosomes[read.reference_id]

    in_gap = False
    gap_length = 0
    gap_start = gap_end = 0
    for query_pos, ref_pos in read.get_aligned_pairs():
        if ref_pos is None:
            # Query-only column: carries no reference coordinate.
            continue

        if query_pos is None:
            if in_gap:
                gap_length += 1
            else:
                in_gap = True
                gap_length = 1
                gap_start = ref_pos + 1
            # Always track the last gapped position, so a one-base gap has
            # a valid end as well.
            gap_end = ref_pos + 1
            continue

        if not in_gap:
            continue

        # First aligned base after a gap: close the run and evaluate it.
        in_gap = False
        if gap_length > min_length:
            key = (create_key(chrom, gap_start), create_key(chrom, gap_end))
            if key in junctions:
                junctions[key].increment()
            else:
                junctions[key] = sv('DEL')
        gap_length = 0
'''
    for cigar in read.cigartuples:
        if cigar[0] == 2 and cigar[1] > 2000:
            key = (create_key(chrom, ref_start + ref_length - 1), create_key(chrom, ref_start + ref_length + cigar[1] - 1))
            if key in junctions:
                junctions[key].increment()
            else:
                 junctions[key] = sv('DEL')
                 
            ref_length += cigar[1]

        elif cigar[0] == 0 or cigar[0] == 3:
            ref_length += cigar[1]
'''
def extract_information(bamFile, filter_name, out_path='split_information.txt.gz'):
    '''Collect the split-alignment segments of each read and call CIGAR deletions.

    Every fetched alignment is skipped when its read name is in
    ``filter_name``, when it is unmapped, or when its mapping quality is
    below 30. For the rest:

    * if the CIGAR contains a deletion longer than 2000 bases,
      ``call_deletion_cigar`` is invoked to update ``junctions``;
    * if the alignment carries an ``SA`` tag, one segment record
      ``[chrom, ref_start, ref_end, strand, read_start, read_end]`` (all
      coordinates 1-based) is stored under the read name.

    Each read's segments are then sorted by ``read_start`` and written to
    ``out_path`` as gzip text: the read name followed by one
    comma-joined segment per tab-separated column.

    Args:
        bamFile: path of the BAM file to read.
        filter_name: iterable of read names to discard entirely (the script
            passes the names of reads that have secondary alignments).
        out_path: destination of the gzip-compressed segment table.

    Returns:
        ``(ex_info, junctions)``: a dict mapping read name to its sorted
        list of segment records, and a dict of junction records.
    '''
    # A set makes the per-read membership test O(1); the original list made
    # the whole scan quadratic in the number of filtered names.
    excluded = set(filter_name)

    # cigartuples operation codes (SAM specification).
    deletion_op = 2
    aligned_query_ops = (0, 1, 7, 8)  # M, I, =, X: aligned bases of the read
    clip_ops = (4, 5)                 # S, H: clipped bases

    ex_info = dict()
    junctions = dict()

    # The context manager closes the BAM handle even if a read raises.
    with pysam.AlignmentFile(bamFile, "rb") as bamfile:
        chromosomes = bamfile.references
        for read in bamfile.fetch():
            if read.query_name in excluded:  # discard secondary alignment
                continue

            if read.is_unmapped or read.mapping_quality < 30:
                continue

            cigar = read.cigartuples
            if not cigar:
                # No CIGAR available: nothing can be derived from this record.
                continue

            if any(op == deletion_op and length > 2000 for op, length in cigar):
                call_deletion_cigar(read, junctions, chromosomes)

            # Only alignments that are part of a split read are recorded.
            if not read.has_tag('SA'):
                continue

            positions = read.get_reference_positions()
            if not positions:
                continue

            strand = '-' if read.is_reverse else '+'
            read_length = sum(length for op, length in cigar if op in aligned_query_ops)

            # Read coordinates are expressed in the read's original
            # orientation, so the clip that precedes this segment is the
            # first CIGAR operation for forward alignments and the last one
            # for reverse alignments.
            lead_op, lead_length = cigar[-1] if read.is_reverse else cigar[0]
            read_start = 1 + (lead_length if lead_op in clip_ops else 0)
            read_end = read_start + read_length - 1

            reference_start = positions[0] + 1
            reference_end = positions[-1] + 1

            ex_info.setdefault(read.query_name, []).append(
                [chromosomes[read.reference_id], reference_start, reference_end,
                 strand, read_start, read_end])

    with gzip.open(out_path, 'wt') as w:
        for name, segments in ex_info.items():
            segments.sort(key=lambda segment: segment[4])
            columns = [name]
            columns.extend(','.join(map(str, segment)) for segment in segments)
            print('\t'.join(columns), file=w)

    return ex_info, junctions


def call_translocation(read, junctions):
    '''Register a ``TRA`` junction between the previous and the current segment.

    The junction joins ``prev_chrom:prev_end`` and ``chrom:start``. The
    previous segment's breakpoint comes first unless its chromosome name
    compares greater than the current one, in which case the pair is
    reversed. An existing junction gets its support incremented; otherwise
    a new record is stored in ``junctions``.
    '''
    prev_key = create_key(read.prev_chrom, read.prev_end)
    curr_key = create_key(read.chrom, read.start)

    if read.prev_chrom > read.chrom:
        key = (curr_key, prev_key)
    else:
        key = (prev_key, curr_key)

    if key not in junctions:
        junctions[key] = sv('TRA')
        return
    junctions[key].increment()

def call_inversion(read, junctions):
    '''Register an ``INV`` junction between the previous and the current segment.

    The junction joins ``prev_chrom:prev_end`` and ``chrom:start``; the
    previous segment's breakpoint comes first when ``prev_end < start``,
    otherwise the pair is reversed. An existing junction gets its support
    incremented; otherwise a new record is stored in ``junctions``.
    '''
    prev_key = create_key(read.prev_chrom, read.prev_end)
    curr_key = create_key(read.chrom, read.start)

    if read.prev_end < read.start:
        key = (prev_key, curr_key)
    else:
        key = (curr_key, prev_key)

    if key not in junctions:
        junctions[key] = sv('INV')
        return
    junctions[key].increment()

def call_deletion(read, junctions):
    '''Register a ``DEL`` junction between the previous and the current segment.

    The junction joins ``prev_chrom:prev_end`` and ``chrom:start``; the
    previous segment's breakpoint comes first when ``prev_end < start``,
    otherwise the pair is reversed. An existing junction gets its support
    incremented; otherwise a new record is stored in ``junctions``.
    '''
    breakpoints = [
        create_key(read.prev_chrom, read.prev_end),
        create_key(read.chrom, read.start),
    ]
    if not (read.prev_end < read.start):
        breakpoints.reverse()
    key = tuple(breakpoints)

    if key not in junctions:
        junctions[key] = sv('DEL')
        return
    junctions[key].increment()

def call_duplication(read, junctions):
    '''Register a ``DUP`` junction between the previous and the current segment.

    The junction joins ``prev_chrom:prev_end`` and ``chrom:start``. The
    previous segment's breakpoint comes first when the current strand is
    ``'+'``; for any other strand the pair is reversed. An existing junction
    gets its support incremented; otherwise a new record is stored in
    ``junctions``.
    '''
    prev_key = create_key(read.prev_chrom, read.prev_end)
    curr_key = create_key(read.chrom, read.start)
    key = (prev_key, curr_key) if read.strand == '+' else (curr_key, prev_key)

    if key not in junctions:
        junctions[key] = sv('DUP')
        return
    junctions[key].increment()


def call_structural_variants(ex_info, junctions, out_path='junctions.txt.gz',
                             max_read_gap=300, min_sv_size=2000):
    '''Classify junctions between consecutive segments of split reads and write all junctions.

    For every read with at least two segments, each segment is compared
    with the one before it. The pair is skipped when the previous segment's
    read end and the current segment's read start are more than
    ``max_read_gap`` apart. Otherwise the junction is classified:

    * different chromosomes -> translocation;
    * same chromosome, different strands, breakpoints more than
      ``min_sv_size`` apart -> inversion;
    * same strand, current segment more than ``min_sv_size`` downstream of
      the previous one (in strand direction) -> deletion;
    * same strand, current segment more than ``min_sv_size`` upstream of
      the previous one (in strand direction) -> duplication.

    Finally every entry of ``junctions`` (including entries that were
    already present on entry) is written to ``out_path`` as gzip text with
    the tab-separated columns: first breakpoint, second breakpoint, type,
    support.

    Args:
        ex_info: dict mapping read name to a list of segment records
            ``[chrom, ref_start, ref_end, strand, read_start, read_end]``.
        junctions: dict of junction records; updated in place.
        out_path: destination of the gzip-compressed junction table.
        max_read_gap: largest allowed distance, in read coordinates,
            between two segments for them to be treated as adjacent.
        min_sv_size: distance on the reference that two breakpoints on the
            same chromosome must exceed to be reported.

    Returns:
        None.
    '''
    for segments in ex_info.values():
        if len(segments) < 2:
            # A single segment has no junction to classify.
            continue

        read = sub_read_info()

        for i, item in enumerate(segments):
            chrom, ref_start, ref_end, strand, read_start, read_end = item[:6]
            read.update_current_subread(chrom, strand, ref_start, ref_end, read_start, read_end)

            # The first segment has no predecessor, and segments that are
            # far apart on the read are not treated as one junction.
            adjacent = i > 0 and abs(read.prev_read_end - read.read_start) <= max_read_gap

            if adjacent:
                if read.prev_chrom != read.chrom:
                    call_translocation(read, junctions)

                elif read.prev_strand != read.strand:
                    if abs(read.prev_end - read.start) > min_sv_size:
                        call_inversion(read, junctions)

                elif (read.strand == '+' and read.start - read.prev_end > min_sv_size) or \
                        (read.strand == '-' and read.prev_end - read.start > min_sv_size):
                    call_deletion(read, junctions)

                elif (read.strand == '+' and read.prev_end - read.start > min_sv_size) or \
                        (read.strand == '-' and read.start - read.prev_end > min_sv_size):
                    call_duplication(read, junctions)

            # The current segment becomes the reference for the next one.
            read.update_prev_subread()

    with gzip.open(out_path, 'wt') as w:
        for (first, second), record in junctions.items():
            print(first, second, record.sv_type, record.support, sep='\t', file=w)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # The BAM path is mandatory: without it None would be handed to pysam.
    parser.add_argument("--bam", required=True, help="PATH of BAM file")
    args = parser.parse_args()

    # Pass 1: names of reads to discard.
    filter_name = filter_reads(args.bam)
    print("Done")

    # Pass 2: per-read split segments plus deletions found in CIGAR strings.
    ex_info, junctions = extract_information(args.bam, filter_name)
    print("Done")

    # Classify junctions between consecutive segments and write the results.
    call_structural_variants(ex_info, junctions)
